#include <Common/error.h>
#include <Common/util.h>
#include <Compiler/compiler.h>
#include <Compiler/error_reporter.h>
#include <VM/object.h>

#include <stdexcept>

// Compiles a whole source text into the implicit top-level function "MAIN".
// `line` is handed to the scanner as the starting line number.
FunctionPtr Compiler::compile(std::string_view source, int line)
{
    // Fresh scanner state and a top-level (non-function) scope.
    scanner.initScanner(source, line);
    scope = std::make_unique<Scope>(FunctionType::None);
    advance(); // load the first token into `peek`

    // Reserve local slot 0 for the implicit main function. Its name is the
    // empty string, so user code can never refer to it.
    Token mainSlot{TokenType::FUNC, "", 0, 0};
    scope->locals.emplace_back(mainSlot, 0, true);

    for (; !isAtEnd();)
        parseDeclaration();

    emit(OpCode::RETURN, peek);

    std::unique_ptr<Chunk> mainChunk = std::move(scope->chunk);
    return std::make_shared<FunctionObject>("MAIN", 0, std::move(mainChunk));
}

// Compiles one declaration (var / func) or statement. A compile error is
// reported and the parser resynchronizes at the next statement boundary, so
// several errors can be reported in one pass.
void Compiler::parseDeclaration()
{
    try
    {
        if (peekIs(TokenType::VAR))
            parseVarDeclStmt();
        else if (peekIs(TokenType::FUNC))
            parseFuncDeclStmt();
        else
            parseStatement();
    }
    catch (const Error &e)
    {
        reporter.report(e);
        // The aborted statement may have left a variable read pending. It
        // belongs to code that is being discarded, so it must not be flushed
        // into whatever statement is compiled next.
        delayGet.reset();
        synchronize();
    }
}

void Compiler::parseFuncDeclStmt()
{
    Token tok = advance(); // "func"

    Token func_name = expect(TokenType::IDENTIFIER, "Expected function name");
    declareVariable(func_name);
    scope->locals.back().defined = true; // 允许函数递归调用

    parseFuncBody(FunctionType::Function, func_name);

    defineVariable(func_name);
}

// Compiles a parameter list and body into a new FunctionObject, then emits it
// as a constant in the enclosing chunk.
//
// Compiler state that belongs to the enclosing function (its Scope and its
// loop/break bookkeeping) is saved on entry and restored on exit, including
// when a compile error propagates out of the parameter list or body.
void Compiler::parseFuncBody(FunctionType type, const Token &name)
{
    // Open a fresh compilation scope; the enclosing one is owned by it.
    scope = std::make_unique<Scope>(type, std::move(scope));
    // Slot 0 holds the function itself, which lets the body call itself by name.
    scope->locals.emplace_back(name, scope->scope_depth, true);
    scope->beginScope();

    // Loop bookkeeping refers to the enclosing function's chunk. A loop that
    // surrounds this declaration must not be visible to `break`/`continue`
    // inside the body, so start the body in the "not in a loop" state.
    const auto enclosingLoopStart = innerMostLoopStart;
    const auto enclosingLoopDepth = innerMostLoopDepth;
    auto enclosingBreaks = std::move(unpatchedBreaks);
    unpatchedBreaks.clear();
    innerMostLoopStart = -1;
    innerMostLoopDepth = 0;

    // Restores the enclosing function's compiler state and hands back this
    // function's finished chunk.
    auto leaveFunction = [&]() {
        innerMostLoopStart = enclosingLoopStart;
        innerMostLoopDepth = enclosingLoopDepth;
        unpatchedBreaks = std::move(enclosingBreaks);
        std::unique_ptr<Chunk> chunk = std::move(scope->chunk);
        scope.reset(scope->enclosing.release());
        return chunk;
    };

    int arity = 0;
    try
    {
        expect(TokenType::LPAREN, "Expected '(' before parameter list");
        if (!peekIs(TokenType::RPAREN))
        {
            do
            {
                Token param = expect(TokenType::IDENTIFIER, "Expected a parameter name");
                declareVariable(param);
                defineVariable(param);
                arity++;
            } while (advanceIf(TokenType::COMMA));
        }
        expect(TokenType::RPAREN, "Expected ')' after parameter list");

        expect(TokenType::LBRACE, "Expected '{' before func_body");
        parseBlock();

        scope->endScope();
        // Implicit `return nil` for bodies that fall off the end. A body that
        // executes an explicit return never reaches this.
        emitReturn();
    }
    catch (...)
    {
        // Never leave the compiler inside a half-compiled function: later
        // code must keep compiling into the enclosing chunk.
        delayGet.reset();
        leaveFunction();
        throw;
    }

    std::unique_ptr<Chunk> chunk = leaveFunction();

    FunctionPtr func = std::make_shared<FunctionObject>(std::string(name.lexeme), arity, std::move(chunk));
    emitConstant(func, name);
}

void Compiler::parseVarDeclStmt()
{
    advance(); // 跳过"var"
    do
    {
        Token varName = expect(TokenType::IDENTIFIER, "Expected variable name");
        declareVariable(varName);

        if (advanceIf(TokenType::EQ))
        {
            parseAssign();
        }
        else
        {
            // 没有赋初值，那么就赋予Nil
            emit(OpCode::NIL, peek);
        }

        defineVariable(varName);
    } while (advanceIf(TokenType::COMMA));

    expect(TokenType::SEMICOLON, "Expected ';'");
}

// Declares a variable in the current scope. Inside any nested scope it becomes
// a local slot. At global depth the variable's name is pushed as a string
// constant for the later DEFINE_GLOBAL.
void Compiler::declareVariable(const Token &name)
{
    if (scope->scope_depth != 0)
    {
        scope->declareLocal(name);
        return;
    }

    emitConstant(std::string(name.lexeme), name);
}

// Completes a declaration once its initial value is on the stack: locals are
// marked as defined, and globals get a DEFINE_GLOBAL instruction.
void Compiler::defineVariable(const Token &name)
{
    if (scope->scope_depth <= 0)
        emit(OpCode::DEFINE_GLOBAL, name);
    else
        scope->defineLocal();
}

// Dispatches on the leading token to the matching statement parser. Anything
// that is not a keyword statement or a `{ ... }` block is an expression statement.
void Compiler::parseStatement()
{
    switch (peek.type)
    {
    case TokenType::IF:
        parseIfStmt();
        break;
    case TokenType::FOR:
        parseForStmt();
        break;
    case TokenType::WHILE:
        parseWhileStmt();
        break;
    case TokenType::CONTINUE:
        parseContinueStmt();
        break;
    case TokenType::BREAK:
        parseBreakStmt();
        break;
    case TokenType::RETURN:
        parseReturnStmt();
        break;
    case TokenType::LBRACE:
        // A bare block introduces its own lexical scope.
        advance(); // '{'
        scope->beginScope();
        parseBlock();
        scope->endScope();
        break;
    default:
        parseExpressionStmt();
        break;
    }
}

// Compiles `if (condition) statement [else statement]`.
// The condition is popped explicitly on both the then-path and the
// else-path below, so JUMP_IF_FALSE is evidently expected to leave it on the stack.
void Compiler::parseIfStmt()
{
    Token tok = advance(); // "If"
    expect(TokenType::LPAREN, "Expect '(' before condition");
    parseExpression();
    expect(TokenType::RPAREN, "Expect ')' after condition");

    // If the condition is false, skip the then-branch.
    int thenJump = emitJump(OpCode::JUMP_IF_FALSE, tok);
    emitPop(); // pop the condition (then-path)

    parseStatement(); // "then"

    // After the then-branch runs, unconditionally jump over the else-branch.
    // In other words, if/else is always compiled as a pair,
    // even when the user wrote no else.
    int elseJump = emitJump(OpCode::JUMP, tok);

    patchJump(thenJump);
    emitPop(); // pop the condition on the path taken when thenJump fires

    if (advanceIf(TokenType::ELSE))
        parseStatement(); // "else"

    patchJump(elseJump);
}

// Compiles `for (init; condition; increment) body`; each clause may be empty.
// The initializer lives in its own scope, which is closed after the loop.
// Bytecode layout:
//   init
//   loopStart:  condition; JUMP_IF_FALSE exit; POP
//               JUMP body
//   increment:  increment; POP; LOOP loopStart
//   body:       body; LOOP increment (or loopStart when there is no increment)
//   exit:       POP            (condition, only if a condition exists)
//   breaks land here, then the initializer scope is closed.
void Compiler::parseForStmt()
{
    // for (var a = 0; a < 10; a = a + 1){}
    Token tok = advance(); // "for"
    scope->beginScope();

    expect(TokenType::LPAREN, "Expected '(' after for");
    if (advanceIf(TokenType::SEMICOLON))
    {
        // No initializer.
    }
    else if (peekIs(TokenType::VAR))
    {
        parseVarDeclStmt();
    }
    else
    {
        parseExpressionStmt();
    }

    // Save the enclosing loop's bookkeeping. Pending breaks of an enclosing
    // loop are set aside, so this loop only patches the breaks it owns.
    int surroundingLoopStart = innerMostLoopStart;
    size_t surroundingLoopDepth = innerMostLoopDepth;
    auto surroundingBreaks = std::move(unpatchedBreaks);
    unpatchedBreaks.clear();
    innerMostLoopStart = scope->chunk->size();
    innerMostLoopDepth = scope->scope_depth;

    int exitJump = -1; // without a condition the loop has no exit jump
    if (!advanceIf(TokenType::SEMICOLON))
    {
        parseExpression();
        expect(TokenType::SEMICOLON, "Expected ';' after condition");

        exitJump = emitJump(OpCode::JUMP_IF_FALSE, tok);
        emitPop();
    }

    // The increment runs after the body but appears before it in the source.
    // So jump over the increment into the body first, and let the body loop
    // back to the increment, which in turn loops back to the condition.
    if (!advanceIf(TokenType::RPAREN))
    {
        int bodyJump = emitJump(OpCode::JUMP, tok); // jump past the increment, into the body
        int incrementStart = scope->chunk->size();

        parseExpression();
        emitPop();
        expect(TokenType::RPAREN, "Expected ')' to close up for clauses");

        emitLoop(innerMostLoopStart, peek); // after the increment, jump back to the condition
        innerMostLoopStart = incrementStart; // `continue` now targets the increment
        patchJump(bodyJump);
    }

    parseStatement();
    emitLoop(innerMostLoopStart, peek); // after the body, jump to the increment (or condition)

    if (exitJump != -1)
    {
        patchJump(exitJump);
        emitPop();
    }

    innerMostLoopStart = surroundingLoopStart;
    innerMostLoopDepth = surroundingLoopDepth;

    // Patch this loop's breaks to land here, then restore the enclosing loop's pending breaks.
    for (size_t target : unpatchedBreaks)
        patchJump(target);
    unpatchedBreaks = std::move(surroundingBreaks);

    scope->endScope();
}

// Compiles `while (condition) body`.
// Bytecode layout:
//   loopStart: condition; JUMP_IF_FALSE exit; POP; body; LOOP loopStart
//   exit:      POP
//   breaks land here (after the exit POP; the condition was already popped on the break path).
void Compiler::parseWhileStmt()
{
    Token tok = advance(); // "while"

    // Save the enclosing loop's bookkeeping. Pending breaks of an enclosing
    // loop are set aside, so this loop only patches the breaks it owns.
    int surroundingLoopStart = innerMostLoopStart;
    size_t surroundingLoopDepth = innerMostLoopDepth;
    auto surroundingBreaks = std::move(unpatchedBreaks);
    unpatchedBreaks.clear();
    innerMostLoopStart = scope->chunk->size();
    innerMostLoopDepth = scope->scope_depth;

    expect(TokenType::LPAREN, "Expected '(' before condition");
    parseExpression();
    expect(TokenType::RPAREN, "Expected ')' after condition");

    size_t endJump = emitJump(OpCode::JUMP_IF_FALSE, tok);
    emitPop();

    parseStatement();
    emitLoop(innerMostLoopStart, peek); // go back to loopStart

    patchJump(endJump);
    emitPop();

    innerMostLoopStart = surroundingLoopStart;
    innerMostLoopDepth = surroundingLoopDepth;

    for (size_t target : unpatchedBreaks)
        patchJump(target);
    unpatchedBreaks = std::move(surroundingBreaks);
}

void Compiler::parseBreakStmt()
{
    Token tok = advance(); // "break";
    if (innerMostLoopDepth == 0)
        throw Error(tok, "'break' can only be used inside a loop");

    unpatchedBreaks.push_back(emitJump(OpCode::JUMP, tok));
    expect(TokenType::SEMICOLON, "Expected ';' after break");
}

// Compiles `continue;`: discards locals declared inside the loop body, then
// loops back to the innermost loop's continue target.
void Compiler::parseContinueStmt()
{
    Token tok = advance(); // "continue"
    if (innerMostLoopStart == -1)
        throw Error(tok, "'continue' can only be used inside a loop");

    // Locals deeper than the loop's own depth are still on the stack. The
    // backward jump bypasses the end of their scopes, so pop them here.
    auto local = scope->locals.rbegin();
    while (local != scope->locals.rend() && local->depth > innerMostLoopDepth)
    {
        emitPop();
        ++local;
    }

    emitLoop(innerMostLoopStart, tok);
    expect(TokenType::SEMICOLON, "Expected ';' after continue");
}

void Compiler::parseReturnStmt()
{
    Token tok = advance(); // "return"

    if (scope->type == FunctionType::None)
    {
        throw Error(tok, "Can't return outside of a function");
    }

    if (advanceIf(TokenType::SEMICOLON))
    {
        emitReturn();
    }
    else
    {
        parseTernary();
        expect(TokenType::SEMICOLON, "Expected ';' after return value");
        emit(OpCode::RETURN, tok);
    }
}

// Compiles declarations up to the closing '}' (the opening '{' has already
// been consumed by the caller). Scope handling is left to the caller.
void Compiler::parseBlock()
{
    for (;;)
    {
        if (isAtEnd() || peekIs(TokenType::RBRACE))
            break;
        parseDeclaration();
    }

    expect(TokenType::RBRACE, "Expected '}' after block");
}

// Compiles `expression ;`. The expression's value is discarded with a POP
// once the terminating ';' has been matched.
void Compiler::parseExpressionStmt()
{
    parseExpression();
    expect(TokenType::SEMICOLON, "Expected ';'");
    emitPop();
}

// Entry point of the expression grammar; the lowest-precedence level is the
// comma operator.
void Compiler::parseExpression()
{
    parseComma();
}

// Compiles `a, b, ..., z`. Like C's comma operator, each operand is evaluated
// left to right and discarded, and the value of the whole expression is the
// last operand. Exactly one value remains on the stack, so callers such as
// parseExpressionStmt (which pops one) stay balanced.
void Compiler::parseComma()
{
    parseAssign();
    while (advanceIf(TokenType::COMMA))
    {
        // Discard the previous operand before evaluating the next one. This
        // keeps the stack shallow and makes the last operand the result.
        emitPop();
        parseAssign();
    }
}

void Compiler::parseAssign()
{
    parseTernary();
    if (!peekIs(TokenType::EQ))
        return;

    Token op = advance(); // 为之后允许，+=、-=...做准备

    // 有效的赋值语句：xxx[1]().yyy = 1，'='之前一定有一个Get语句
    if (!delayGet)
        throw Error(op, "Invalid left-hand side of assignment");

    auto opCode = delayGet->op;
    auto argument = delayGet->arg;
    delayGet.reset();

    parseAssign();
    if (opCode == OpCode::GET_GLOBAL)
        emit(OpCode::SET_GLOBAL, op, argument);
    else
        emit(OpCode::SET_LOCAL, op, argument);
}

// Compiles `condition ? thenExpr : elseExpr`, or just the `or` level when no
// '?' follows. Each branch pops the condition before evaluating its own value.
void Compiler::parseTernary()
{
    parseOr();
    if (!peekIs(TokenType::QUESTION_MARK))
        return;

    Token tok = advance(); // '?'
    // False condition: skip to the else-branch.
    size_t elseJump = emitJump(OpCode::JUMP_IF_FALSE, tok);

    // Then-path: discard the condition, evaluate the then-expression, skip the else-branch.
    emitPop();
    parseAssign();
    size_t endJump = emitJump(OpCode::JUMP, tok);

    // Else-path: discard the condition, evaluate the else-expression.
    patchJump(elseJump);
    emitPop();

    expect(TokenType::COLON, "Expected ':' for ternary else expression");
    parseAssign();

    patchJump(endJump);
}

// Compiles `a or b or ...` with short-circuit evaluation. The result is the
// first truthy operand, or the last operand when none is truthy. Handles any
// number of chained `or`s (the original handled only one).
void Compiler::parseOr()
{
    parseAnd();
    while (peekIs(TokenType::OR))
    {
        Token tok = advance(); // "or"
        // `or` is true as soon as one operand is true: a truthy value jumps to
        // the end and stays on the stack as the result.
        size_t endJump = emitJump(OpCode::JUMP_IF_TRUE, tok);

        // Otherwise discard the falsy value and evaluate the next operand.
        emitPop();
        parseAnd();

        patchJump(endJump);
    }
}

// Compiles `a and b` with short-circuit evaluation. The right operand is
// parsed by recursing into parseAnd, so chains like `a and b and c` nest to
// the right.
void Compiler::parseAnd()
{
    parseEquality();
    if (!peekIs(TokenType::AND))
        return;

    Token tok = advance(); // "and"
    // `and` is false as soon as one operand is false: a falsy value jumps
    // straight to the end and, not being popped on that path, is the result.
    size_t endJump = emitJump(OpCode::JUMP_IF_FALSE, tok);

    // Otherwise the current value is not false: pop it and evaluate the next operand.
    emitPop();
    parseAnd();

    patchJump(endJump);
}

void Compiler::parseEquality()
{
    static const std::unordered_map<TokenType, OpCode> ops{
        {TokenType::EQEQ, OpCode::EQEQ},
        {TokenType::BANGEQ, OpCode::NEQ}};

    parseBinary(&Compiler::parseComparison, ops);
}

void Compiler::parseComparison()
{
    static const std::unordered_map<TokenType, OpCode> ops{
        {TokenType::GT, OpCode::GT},
        {TokenType::GTE, OpCode::GTE},
        {TokenType::LT, OpCode::LT},
        {TokenType::LTE, OpCode::LTE}};

    parseBinary(&Compiler::parseTerm, ops);
}

void Compiler::parseTerm()
{
    static const std::unordered_map<TokenType, OpCode> ops{
        {TokenType::PLUS, OpCode::ADD},
        {TokenType::MINUS, OpCode::SUBTRACT}};
    parseBinary(&Compiler::parseFactor, ops);
}

void Compiler::parseFactor()
{
    static const std::unordered_map<TokenType, OpCode> ops{
        {TokenType::MUL, OpCode::MULTIPLY},
        {TokenType::DIV, OpCode::DIVIDE}};
    parseBinary(&Compiler::parseUnary, ops);
}

// unary := ( "-" | "!" ) unary | prefix
// The operand is compiled first; the operator instruction follows it.
void Compiler::parseUnary()
{
    // Read-only lookup table: const, like the tables of the binary levels, so
    // no caller can mutate the shared function-local static.
    static const std::unordered_map<TokenType, OpCode> ops{
        {TokenType::MINUS, OpCode::NEGATE},
        {TokenType::BANG, OpCode::NOT}};

    auto op = ops.find(peek.type);
    if (op == ops.end())
        return parsePrefix();

    Token tok = advance();
    parseUnary();
    emit(op->second, tok);
}

void Compiler::parsePrefix()
{
    if (peekIs(TokenType::PLUS_PLUS, TokenType::MINUS_MINUS))
    {
        Token op_tok = advance();
        parseCall();

        if (!delayGet)
            throw Error(op_tok, "Can only '" + std::string(op_tok.lexeme) + "' a variable");

        OpCode op = op_tok.type == TokenType::PLUS_PLUS ? OpCode::ADD : OpCode::SUBTRACT;
        return handleIDCrement(op);
    }

    return parsePostfix();
}

void Compiler::parsePostfix()
{
    parseCall();

    if (peekIs(TokenType::PLUS_PLUS, TokenType::MINUS_MINUS))
    {
        Token op_tok = advance();
        if (!delayGet)
            throw Error(op_tok, "Can only '" + std::string(op_tok.lexeme) + "' a variable");

        OpCode op = op_tok.type == TokenType::PLUS_PLUS ? OpCode::ADD : OpCode::SUBTRACT;
        return handleIDCrement(op, true);
    }
}

void Compiler::parseCall()
{
    parsePrimary();

    while (true)
    {
        if (peekIs(TokenType::LPAREN))
        {
            Token tok = advance();
            u8_t args_num = parseArgument(TokenType::RPAREN);
            emit(OpCode::CALL, tok, args_num);
        }
        else
            break;
    }
}

// Compiles a comma-separated argument list up to and including `ending`.
// Returns the number of arguments. The count travels in a one-byte operand,
// so more than 255 arguments is a compile error instead of silently wrapping to 0.
u8_t Compiler::parseArgument(TokenType ending)
{
    u8_t args = 0;
    if (!peekIs(ending))
    {
        do
        {
            if (args == 0xff)
                throw Error(peek, "Can't have more than 255 arguments");

            parseTernary();
            args++;
        } while (advanceIf(TokenType::COMMA));
    }
    expect(ending, format("Expected '%s' after arguments", Token::TypeSymbol(ending)));

    return args;
}

// primary := "(" expression ")" | IDENTIFIER | STRING | NUMBER | nil | true | false
// Anything else is a syntax error at the current token.
void Compiler::parsePrimary()
{
    switch (peek.type)
    {
    case TokenType::LPAREN:
        return parseGroup();
    case TokenType::IDENTIFIER:
        return parseIdentifier();
    case TokenType::STRING:
        return parseString();
    case TokenType::NUMBER:
        return parseNumber();

    // Literal keywords compile to a single instruction.
    case TokenType::NIL:
        return emit(OpCode::NIL, advance());
    case TokenType::TRUE:
        return emit(OpCode::TRUE, advance());
    case TokenType::FALSE:
        return emit(OpCode::FALSE, advance());

    default:
        break;
    }

    throw Error(peek, isAtEnd() ? "Unexpected end of input." : "Unexpected token '" + std::string(peek.lexeme) + "'.");
}

// Compiles a variable reference. The actual GET instruction is deferred in
// `delayGet`, so a following '=' (or ++/--) can turn it into a SET instead.
// For globals the name constant is pushed immediately.
void Compiler::parseIdentifier()
{
    Token identifier = advance();
    // Only one read can be pending; a previous one is emitted now.
    if (delayGet)
        emitDelayGet();

    const int slot = scope->resolveLocal(identifier);
    if (slot == -1)
    {
        // Global: push the name now, keep the GET_GLOBAL pending.
        size_t index = emitConstant(std::string(identifier.lexeme), identifier);
        delayGet = {OpCode::GET_GLOBAL, identifier, index};
    }
    else
    {
        delayGet = {OpCode::GET_LOCAL, identifier, static_cast<u8_t>(slot)};
    }
}

// Compiles a string literal as a constant. The lexeme still includes the
// surrounding double quotes, which are stripped here.
void Compiler::parseString()
{
    std::string str(peek.lexeme.substr(1, peek.lexeme.size() - 2));
    Token tok = advance();
    emitConstant(str, tok);
}

// Compiles a parenthesized expression `( expression )`.
void Compiler::parseGroup()
{
    advance();
    parseExpression();
    expect(TokenType::RPAREN, "Expected ')' to close up");
}

void Compiler::parseNumber()
{
    double num = std::stod(std::string(peek.lexeme));
    Token tok = advance();
    emitConstant(num, tok);
}

// Generic left-associative binary level: operand ( op operand )*.
// `ops` maps each accepted operator token to its instruction. The instruction
// is emitted right after its right operand.
void Compiler::parseBinary(const std::function<void(Compiler *)> &parseOperand, const std::unordered_map<TokenType, OpCode> &ops)
{
    parseOperand(this);

    for (auto match = ops.find(peek.type); match != ops.end(); match = ops.find(peek.type))
    {
        Token opTok = advance();
        parseOperand(this);
        emit(match->second, opTok);
    }
}

// Desugars ++/-- on the variable whose read is pending in `delayGet`:
//   ++a  ->  a = a + 1            (value: the new value)
//   a++  ->  a, a = a + 1; POP    (value: the original value)
// `op` is ADD or SUBTRACT. For globals, delayGet->arg is the constant-pool
// index of the variable name. For locals it is a stack slot.
void Compiler::handleIDCrement(OpCode op, bool postfix)
{
    const bool global = delayGet->op == OpCode::GET_GLOBAL;
    const OpCode setCode = global ? OpCode::SET_GLOBAL : OpCode::SET_LOCAL;
    const u8_t index = delayGet->arg.value();

    // Pushes another copy of a global's name constant. This is only valid for
    // globals: for a local, `index` is a stack slot, and using it as a
    // constant index could read past the constant pool.
    auto pushNameCopy = [this, index]() {
        Value identifier = scope->chunk->getConstant(index);
        scope->chunk->write(OpCode::CONSTANT, delayGet->tok.line);
        scope->chunk->write(scope->chunk->addConstant(identifier), delayGet->tok.line);
    };

    // The postfix form must yield the original value, so read it once first.
    if (postfix)
    {
        auto tmp = delayGet;

        if (global)
        {
            emitDelayGet();
            delayGet = tmp;
            pushNameCopy();
        }

        emitDelayGet();
        delayGet = tmp;
    }

    if (global)
        pushNameCopy();

    emitConstant(1.0, peek);    // note: this flushes the pending GET first
    emit(op, peek);             // add or subtract
    emit(setCode, peek, index); // store the new value

    if (postfix)
        emitPop();
}

// Returns the compiler to its initial state: no scope, no pending read,
// and not inside any loop.
void Compiler::reset()
{
    innerMostLoopStart = -1;
    innerMostLoopDepth = 0;
    unpatchedBreaks.clear();

    delayGet.reset();
    scope.reset();
}

// True when the lookahead token is END_OF_FILE.
bool Compiler::isAtEnd()
{
    return peekIs(TokenType::END_OF_FILE);
}

// True when the lookahead token has the given type. Does not consume it.
bool Compiler::peekIs(TokenType type)
{
    return peek.type == type;
}

// Consumes the lookahead token and returns it, refilling `peek` with the next
// token. ERROR tokens from the scanner are reported and skipped, so `peek`
// is never an ERROR token afterwards.
Token Compiler::advance()
{
    Token tok = peek;

    while (true)
    {
        peek = scanner.scanToken();
        if (!peekIs(TokenType::ERROR))
            return tok;

        // The lexeme is a view; its data() is not guaranteed to be
        // NUL-terminated (it may point into the middle of the source).
        // Copy it so the reporter gets exactly the lexeme's characters.
        reporter.report(peek.line, peek.column, std::string(peek.lexeme).c_str());
    }
}

// Consumes the lookahead token only if it has the given type.
// Returns whether a token was consumed.
bool Compiler::advanceIf(TokenType type)
{
    if (!peekIs(type))
        return false;

    advance();
    return true;
}

// Consumes and returns the lookahead token, which must have the given type.
// Otherwise throws Error(peek, errorMessage) without consuming anything.
Token Compiler::expect(TokenType type, std::string errorMessage)
{
    if (peekIs(type))
        return advance();

    throw Error(peek, std::move(errorMessage));
}

// Error recovery (panic mode): discards tokens until a likely statement
// boundary, so parsing can resume and report further errors. It stops
//   - after a ';' (consumed),
//   - before a statement keyword, or
//   - before a '}' (left for the enclosing parseBlock to close).
// Apart from the initial ';' case, at least one token is always discarded
// first. This guarantees progress, so parseDeclaration cannot loop forever on
// the same token.
void Compiler::synchronize()
{
    // An error reported at a ';' (e.g. `x = ;`): that ';' already ends the
    // broken statement. Consume it and resume right after it, rather than
    // also discarding the following statement.
    if (advanceIf(TokenType::SEMICOLON))
        return;

    while (!isAtEnd())
    {
        advance(); // discard the offending token

        if (advanceIf(TokenType::SEMICOLON))
            return;

        switch (peek.type)
        {
        case TokenType::VAR:
        case TokenType::IF:
        case TokenType::WHILE:
        case TokenType::FOR:
        case TokenType::FUNC:
        case TokenType::CLASS:
        case TokenType::RETURN:
        case TokenType::RBRACE: // don't swallow the closing brace of the current block
            return;
        default:
            break;
        }
    }
}

// Writes one instruction (plus an optional one-byte operand) to the current
// chunk, tagged with the token's line. Any pending variable read is written
// first, so it precedes the new instruction.
void Compiler::emit(OpCode code, const Token &tok, std::optional<u8_t> argument)
{
    if (delayGet)
        emitDelayGet();

    const auto line = tok.line;
    scope->chunk->write(code, line);
    if (argument.has_value())
        scope->chunk->write(*argument, line);
}

// Adds `val` to the current chunk's constant pool and emits CONSTANT <index>.
// Returns the pool index.
// CONSTANT's operand is a single byte, so an index above 255 would be
// silently truncated and load the wrong constant. Report it as a compile
// error instead.
size_t Compiler::emitConstant(Value val, const Token &tok)
{
    size_t index = scope->chunk->addConstant(std::move(val));
    if (index > 0xff)
        throw Error(tok, "Too many constants in one chunk");

    emit(OpCode::CONSTANT, tok, index);

    return index;
}

// Writes the pending variable read (opcode plus optional operand) to the
// current chunk and clears it. Must only be called while a read is pending.
void Compiler::emitDelayGet()
{
    const auto &pending = *delayGet;
    scope->chunk->write(pending.op, pending.tok.line);
    if (pending.arg.has_value())
        scope->chunk->write(*pending.arg, pending.tok.line);

    delayGet.reset();
}

// Emits POP, attributed to the lookahead token's line. Like every emit, it
// first writes any pending variable read.
void Compiler::emitPop()
{
    emit(OpCode::POP, peek);
}

// Emits a forward jump with a two-byte placeholder operand and returns the
// position of the operand's high byte, for patchJump to fill in later.
size_t Compiler::emitJump(OpCode code, const Token &tok)
{
    // Emit the jump instruction; its target is not known yet,
    // so it must be patched later (backpatching).
    emit(code, tok, static_cast<u8_t>(0xff));
    // The operand is two bytes (placeholder 0xffff), so a jump can span at most 65535 bytes.
    scope->chunk->write(static_cast<u8_t>(0xff), tok.line);

    // Both operand bytes are placeholders to be patched later,
    // so return the position of the high byte.
    return scope->chunk->size() - 2;
}

void Compiler::patchJump(int offset)
{
    // 跳跃距离
    int jump = scope->chunk->size() - 2 - offset;

    scope->chunk->setCode(offset, (jump >> 8) & 0xff); // 高位
    scope->chunk->setCode(offset + 1, jump & 0xff);    // 低位
}

void Compiler::emitLoop(int loopStart, const Token &tok)
{
    emit(OpCode::LOOP, tok);

    int jump = scope->chunk->size() - loopStart + 2;

    scope->chunk->write((jump >> 8) & 0xff, tok.line);
    scope->chunk->write(jump & 0xff, tok.line);
}

// Emits an implicit `return nil`: NIL followed by RETURN.
void Compiler::emitReturn()
{
    emit(OpCode::NIL, peek);
    emit(OpCode::RETURN, peek);
}
